# Dirichlet–Multinomial Distribution (dirichlet_multinomial)

The Dirichlet–multinomial (a.k.a. Dirichlet compound multinomial) is a discrete multivariate distribution over count vectors. It appears when you model category probabilities as uncertain: draw probabilities \(p\) from a Dirichlet distribution, then draw counts \(X\) from a multinomial given \(p\).

## Learning goals

By the end you should be able to:

- explain the Dirichlet–multinomial as a “multinomial with random probabilities” and why it captures overdispersion
- write down the PMF and understand its constraints (support + parameter space)
- derive the mean and covariance from the hierarchical model
- sample from it in NumPy and visualize it (1D and simplex plots)
- use SciPy’s `scipy.stats.dirichlet_multinomial` for PMF/moments, and implement missing pieces (CDF/sampling/fit) yourself

## Prerequisites

- basic probability (expectation, variance, conditional expectation)
- familiarity with the multinomial and Dirichlet distributions
- comfort reading Gamma/Beta functions

import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import os
import plotly.io as pio

from scipy import stats
from scipy.optimize import minimize
from scipy.special import digamma, gammaln, logsumexp

pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")

rng = np.random.default_rng(7)
np.set_printoptions(precision=4, suppress=True)


## 1) Title & Classification

- **Name**: `dirichlet_multinomial` (Dirichlet–multinomial, Dirichlet compound multinomial)
- **Type**: **Discrete** (multivariate counts)
- **Support** (for $K$ categories and total count $n$):

  $$
  \mathcal{S}_{n,K} = \left\{x \in \{0,1,2,\dots\}^K : \sum_{i=1}^K x_i = n\right\}
  $$

- **Parameter space**:
  - $n \in \{0,1,2,\dots\}$ (integer total count)
  - $\alpha = (\alpha_1,\dots,\alpha_K)$ with $\alpha_i > 0$
  - define $\alpha_0 = \sum_{i=1}^K \alpha_i$ (total concentration)

A draw $X \sim \text{DirichletMultinomial}(n,\alpha)$ is a **count vector** with a fixed total: $\sum_i X_i = n$.
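The size of this support follows from stars and bars: $|\mathcal{S}_{n,K}| = \binom{n+K-1}{K-1}$. A quick self-contained check (`support_size` is a helper introduced here for illustration):

```python
from math import comb
from itertools import product

def support_size(n: int, k: int) -> int:
    # Stars and bars: number of nonnegative integer k-tuples summing to n
    return comb(n + k - 1, k - 1)

# Brute-force cross-check for a small case
brute = sum(1 for t in product(range(11), repeat=3) if sum(t) == 10)
print(support_size(10, 3), brute)  # 66 66
```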


## 2) Intuition & Motivation

### What it models
The Dirichlet–multinomial models **counts across categories** when the category probabilities themselves vary across trials/replicates.

A common hierarchical story is:

$$
p \sim \mathrm{Dirichlet}(\alpha),
\qquad
X \mid p \sim \mathrm{Multinomial}(n, p).
$$

If $p$ were fixed, you’d just have a multinomial. But if $p$ changes (e.g., each document has a different word distribution), the multinomial is often **too confident**.
The Dirichlet–multinomial captures this extra variability (“**overdispersion**”) and induces **negative correlations** between counts because they must sum to $n$.

### Typical real-world use cases
- **Text / NLP**: bag-of-words counts (posterior predictive of a Dirichlet–multinomial model)
- **Ecology**: species counts across sites with heterogeneous composition
- **Genomics**: overdispersed categorical counts (e.g., allelic counts)
- **A/B testing on categories**: uncertainty in category probabilities across cohorts

### Relations to other distributions
- **Dirichlet + Multinomial**: it is the **Dirichlet mixture** of multinomials.
- **Beta–binomial**: when $K=2$, the first component $X_1$ is Beta–binomial.
- **Multinomial limit**: as $\alpha_0 \to \infty$ with $\alpha/\alpha_0$ fixed, the Dirichlet–multinomial approaches a multinomial with fixed probabilities.
- **Pólya urn**: an equivalent sampling scheme is “reinforcement” sampling where each draw increases the chance of drawing that category again.
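The Pólya-urn view can be checked empirically: start the urn weights at $\alpha$, add one unit of weight to each category as it is drawn, and the resulting counts are Dirichlet–multinomial. A minimal sketch (`polya_urn_counts` is a name introduced here):

```python
import numpy as np

def polya_urn_counts(alpha, n: int, rng: np.random.Generator) -> np.ndarray:
    # Urn weights start at alpha; each draw adds 1 to the drawn category,
    # reinforcing it for subsequent draws.
    w = np.asarray(alpha, dtype=float).copy()
    counts = np.zeros(w.size, dtype=int)
    for _ in range(n):
        i = rng.choice(w.size, p=w / w.sum())
        counts[i] += 1
        w[i] += 1.0
    return counts

rng_urn = np.random.default_rng(0)
alpha = np.array([1.0, 2.0, 3.0])
draws = np.array([polya_urn_counts(alpha, n=10, rng=rng_urn) for _ in range(5000)])
print(draws.mean(axis=0))  # approaches n * alpha / alpha.sum() = [1.67, 3.33, 5.00]
```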


## 3) Formal Definition

Let $X=(X_1,\dots,X_K)$ be a count vector with $\sum_i X_i = n$.

### PMF (discrete analogue of a PDF)
For $x \in \mathcal{S}_{n,K}$:

$$
\Pr(X=x \mid n,\alpha)
= \frac{n!}{\prod_{i=1}^K x_i!}
  \frac{\Gamma(\alpha_0)}{\Gamma(\alpha_0+n)}
  \prod_{i=1}^K \frac{\Gamma(\alpha_i + x_i)}{\Gamma(\alpha_i)}.
$$

Using the rising factorial (Pochhammer symbol) $(a)_m = \Gamma(a+m)/\Gamma(a)$, this can be written:

$$
\Pr(X=x \mid n,\alpha)
= \frac{n!}{\prod_i x_i!}
  \frac{\prod_i (\alpha_i)_{x_i}}{(\alpha_0)_n}.
$$

### CDF
A common multivariate “CDF” is the **lower-orthant CDF**:

$$
F(x) = \Pr(X_1 \le x_1,\dots,X_K \le x_K)
= \sum_{y \in \mathcal{S}_{n,K}:\; y_i \le x_i\;\forall i} \Pr(X=y).
$$

There is no simple closed form in general. For $K=2$, this reduces to the usual **univariate CDF** of the Beta–binomial distribution.
def _validate_alpha(alpha) -> np.ndarray:
    alpha = np.asarray(alpha, dtype=float)
    if alpha.ndim != 1:
        raise ValueError("alpha must be a 1D array of positive values")
    if alpha.size < 2:
        raise ValueError("alpha must have length K>=2")
    if not np.all(np.isfinite(alpha)):
        raise ValueError("alpha must be finite")
    if np.any(alpha <= 0):
        raise ValueError("alpha must be strictly positive")
    return alpha


def _validate_counts(x, k: int) -> np.ndarray:
    x = np.asarray(x)
    if x.ndim == 1:
        x = x[None, :]
    if x.ndim != 2 or x.shape[1] != k:
        raise ValueError(f"x must have shape (k,) or (m,k) with k={k}")

    if not np.issubdtype(x.dtype, np.integer):
        if np.any(np.abs(x - np.round(x)) > 0):
            raise ValueError("x must contain integers")
        x = np.round(x).astype(int)
    else:
        x = x.astype(int)

    if np.any(x < 0):
        raise ValueError("x must be nonnegative")

    return x


def dirichlet_multinomial_logpmf(x, alpha, n: int | None = None) -> np.ndarray:
    '''Vectorized log PMF for the Dirichlet–multinomial.

    Parameters
    ----------
    x:
        Count vector(s), shape (k,) or (m,k). Each row must sum to n.
    alpha:
        Concentration parameters (k,), alpha_i > 0.
    n:
        Total count. If None, inferred from x row sums.
    '''
    alpha = _validate_alpha(alpha)
    x = _validate_counts(x, k=alpha.size)

    row_sums = x.sum(axis=1)
    if n is None:
        n_vec = row_sums
    else:
        if np.any(row_sums != n):
            raise ValueError("Each row of x must sum to n")
        n_vec = np.full_like(row_sums, fill_value=n)

    alpha0 = alpha.sum()

    log_multinomial_coeff = gammaln(n_vec + 1) - np.sum(gammaln(x + 1), axis=1)
    log_norm = gammaln(alpha0) - gammaln(alpha0 + n_vec)
    log_ratio = np.sum(gammaln(alpha + x) - gammaln(alpha), axis=1)

    out = log_multinomial_coeff + log_norm + log_ratio
    return out[0] if out.size == 1 else out


def dirichlet_multinomial_pmf(x, alpha, n: int | None = None) -> np.ndarray:
    return np.exp(dirichlet_multinomial_logpmf(x, alpha=alpha, n=n))


def compositions(n: int, k: int):
    '''Generate all k-tuples of nonnegative integers summing to n (stars and bars).'''
    if k == 1:
        yield (n,)
        return
    for i in range(n + 1):
        for tail in compositions(n - i, k - 1):
            yield (i,) + tail


def enumerate_support(n: int, k: int) -> np.ndarray:
    '''Enumerate the support S_{n,k}. Size is comb(n+k-1, k-1).'''
    return np.array(list(compositions(n, k)), dtype=int)


def dm_cdf_small_n(x, alpha, n: int) -> float:
    '''Lower-orthant CDF by brute-force summation (only feasible for small n,k).'''
    alpha = _validate_alpha(alpha)
    x = _validate_counts(x, k=alpha.size)[0]
    if x.sum() != n:
        raise ValueError("x must sum to n")

    ys = enumerate_support(n=n, k=alpha.size)
    mask = np.all(ys <= x[None, :], axis=1)
    return float(np.sum(dirichlet_multinomial_pmf(ys[mask], alpha=alpha, n=n)))


def simplex_xy_3(counts: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    '''Map 3-category compositions to 2D barycentric coordinates for plotting.'''
    counts = np.asarray(counts, dtype=float)
    counts = np.atleast_2d(counts)
    if counts.shape[1] != 3:
        raise ValueError("simplex_xy_3 expects shape (m,3)")

    n = counts.sum(axis=1)
    if np.any(n <= 0):
        raise ValueError("All rows must sum to a positive n")

    p = counts / n[:, None]
    x = p[:, 1] + 0.5 * p[:, 2]
    y = (np.sqrt(3) / 2.0) * p[:, 2]
    return x, y


def dirichlet_rvs_numpy(alpha, size: int, rng: np.random.Generator) -> np.ndarray:
    '''Sample Dirichlet(alpha) via Gamma normalization (NumPy-only).'''
    alpha = _validate_alpha(alpha)
    g = rng.gamma(shape=alpha, scale=1.0, size=(size, alpha.size))
    return g / g.sum(axis=1, keepdims=True)


def dirichlet_multinomial_rvs_numpy(alpha, n: int, size: int, rng: np.random.Generator) -> np.ndarray:
    '''Sample Dirichlet–multinomial(n, alpha) (NumPy-only).

    Algorithm:
    1) p ~ Dirichlet(alpha)
    2) X | p ~ Multinomial(n, p)
    '''
    alpha = _validate_alpha(alpha)
    ps = dirichlet_rvs_numpy(alpha, size=size, rng=rng)
    out = np.empty((size, alpha.size), dtype=int)
    for i, p in enumerate(ps):
        out[i] = rng.multinomial(n, p)
    return out
# Quick sanity check against SciPy's PMF
alpha = np.array([1.0, 2.0, 3.0])
n = 10
x = np.array([2, 3, 5])

pmf_numpy = dirichlet_multinomial_pmf(x, alpha=alpha, n=n)
pmf_scipy = stats.dirichlet_multinomial.pmf(x, alpha=alpha, n=n)

pmf_numpy, pmf_scipy, float(pmf_numpy - pmf_scipy)
(0.027972027972027885, 0.02797202797202796, -7.632783294297951e-17)


## 4) Moments & Properties

### Mean
Using the hierarchical model $p \sim \mathrm{Dirichlet}(\alpha)$ and $X\mid p \sim \mathrm{Multinomial}(n,p)$,

$$
\mathbb{E}[X_i] = n\,\frac{\alpha_i}{\alpha_0}.
$$

### Covariance

For $i \ne j$:

$$
\mathrm{Cov}(X_i, X_j)
= -\,n\,\frac{\alpha_i\alpha_j}{\alpha_0^2}\,\frac{n+\alpha_0}{\alpha_0+1}.
$$

For the variance (the $i=j$ case):

$$
\mathrm{Var}(X_i)
= n\,\frac{\alpha_i}{\alpha_0}\left(1-\frac{\alpha_i}{\alpha_0}\right)\,\frac{n+\alpha_0}{\alpha_0+1}.
$$

### Marginals (Beta–binomial)
Each component $X_i$ marginally follows a Beta–binomial distribution:

$$
X_i \sim \mathrm{BetaBinomial}\big(n,\; \alpha_i,\; \alpha_0-\alpha_i\big).
$$

This is useful because it gives you **univariate** quantities like skewness and kurtosis for each component.

A clean way to get higher moments is via **factorial moments**. For $r\in\{1,2,3,\dots\}$:

$$
\mathbb{E}\big[X_i^{\underline{r}}\big] = n^{\underline{r}}\,\frac{(\alpha_i)_{r}}{(\alpha_0)_{r}},
$$

where $a^{\underline{r}} = a(a-1)\cdots(a-r+1)$ is the falling factorial and $(a)_r = \Gamma(a+r)/\Gamma(a)$ is the rising factorial. Note the falling factorial appears on the left and in $n^{\underline{r}}$, while the $\alpha$ terms use the rising factorial.
From these you can reconstruct raw/central moments (and thus skewness/kurtosis).
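As a numeric sanity check on this identity, here is a self-contained sketch for $r=2$ and the first component, with the PMF duplicated inline so the snippet stands alone:

```python
import numpy as np
from scipy.special import gammaln

alpha = np.array([1.5, 2.0, 4.5])
a0 = alpha.sum()
n = 8

def dm_pmf(x):
    # Inline Dirichlet–multinomial PMF (duplicated here for self-containedness)
    x = np.asarray(x)
    lp = (gammaln(n + 1) - gammaln(x + 1).sum()
          + gammaln(a0) - gammaln(a0 + n)
          + (gammaln(alpha + x) - gammaln(alpha)).sum())
    return float(np.exp(lp))

# Left side: E[X_1 (X_1 - 1)] (falling factorial, r = 2) by full enumeration
lhs = sum(x1 * (x1 - 1) * dm_pmf([x1, x2, n - x1 - x2])
          for x1 in range(n + 1) for x2 in range(n - x1 + 1))

# Right side: n(n-1) * (alpha_1)_2 / (alpha_0)_2, rising factorials on alpha
rhs = n * (n - 1) * (alpha[0] * (alpha[0] + 1)) / (a0 * (a0 + 1))
print(lhs, rhs)  # the two should agree
```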

### MGF / characteristic function
For a vector $t\in\mathbb{R}^K$, the MGF can be written as a (finite) sum over the support:

$$
M_X(t) = \mathbb{E}[e^{t^\top X}] = \sum_{x\in\mathcal{S}_{n,K}} e^{t^\top x}\,\Pr(X=x).
$$

Equivalently, via the mixture:

$$
M_X(t) = \mathbb{E}_{p\sim\mathrm{Dir}(\alpha)}\left[\left(\sum_{i=1}^K p_i e^{t_i}\right)^n\right].
$$

Closed forms involve special functions (multivariate hypergeometric functions). For small $n$ you can compute it by enumeration.
The characteristic function is $\varphi(\omega)=M_X(i\omega)$.

### Entropy
The Shannon entropy is

$$
H(X) = -\sum_{x\in\mathcal{S}_{n,K}} \Pr(X=x)\,\log \Pr(X=x).
$$

There is no simple closed form in general; you can compute it exactly by enumeration for small $n$, or estimate it by Monte Carlo.
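The Monte Carlo route uses $H(X) = -\mathbb{E}[\log p(X)]$: sample via the hierarchical model and average the log-PMF. A self-contained sketch (the log-PMF is written inline):

```python
import numpy as np
from scipy.special import gammaln

rng_mc = np.random.default_rng(42)
alpha = np.array([1.5, 2.0, 4.5])
a0 = alpha.sum()
n = 20

# Sample via the hierarchical model: p ~ Dirichlet(alpha), X | p ~ Multinomial(n, p)
ps = rng_mc.dirichlet(alpha, size=20_000)
X = np.array([rng_mc.multinomial(n, p) for p in ps])

# H(X) = -E[log p(X)], estimated by averaging the log-PMF over the samples
logp = (gammaln(n + 1) - gammaln(X + 1).sum(axis=1)
        + gammaln(a0) - gammaln(a0 + n)
        + (gammaln(alpha + X) - gammaln(alpha)).sum(axis=1))
print(-logp.mean())  # ≈ 4.885, the exact value from enumeration below
```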
def dm_mean(alpha, n: int) -> np.ndarray:
    alpha = _validate_alpha(alpha)
    return n * alpha / alpha.sum()


def dm_cov(alpha, n: int) -> np.ndarray:
    alpha = _validate_alpha(alpha)
    k = alpha.size
    alpha0 = alpha.sum()

    # Cov(X_i, X_j) = -n * alpha_i*alpha_j / alpha0^2 * (n+alpha0)/(alpha0+1)
    factor = n * (n + alpha0) / (alpha0 + 1.0) / (alpha0**2)
    cov = -factor * np.outer(alpha, alpha)

    # Fix diagonal to variance formula
    p = alpha / alpha0
    var = n * p * (1.0 - p) * (n + alpha0) / (alpha0 + 1.0)
    np.fill_diagonal(cov, var)
    return cov


def beta_binomial_moments_via_factorials(n: int, a: float, b: float):
    '''Return (mean, variance, skewness, excess_kurtosis) for BetaBinomial(n,a,b).

    Uses factorial moments (stable + avoids a giant closed-form expression).
    '''
    a0 = a + b

    # Falling factorial moments of X: E[(X)_r] = (n)_r E[p^r]
    # with p~Beta(a,b), E[p^r] = (a)_r/(a0)_r (rising factorial).
    f1 = n * a / a0
    f2 = n * (n - 1) * a * (a + 1) / (a0 * (a0 + 1)) if n >= 2 else 0.0
    f3 = (
        n * (n - 1) * (n - 2) * a * (a + 1) * (a + 2) / (a0 * (a0 + 1) * (a0 + 2))
        if n >= 3
        else 0.0
    )
    f4 = (
        n
        * (n - 1)
        * (n - 2)
        * (n - 3)
        * a
        * (a + 1)
        * (a + 2)
        * (a + 3)
        / (a0 * (a0 + 1) * (a0 + 2) * (a0 + 3))
        if n >= 4
        else 0.0
    )

    # Stirling-number conversion (X^r = sum_k S(r,k) (X)_k)
    m1 = f1
    m2 = f1 + f2
    m3 = f1 + 3 * f2 + f3
    m4 = f1 + 7 * f2 + 6 * f3 + f4

    mu = m1
    mu2 = m2 - mu**2
    mu3 = m3 - 3 * m2 * mu + 2 * mu**3
    mu4 = m4 - 4 * m3 * mu + 6 * m2 * mu**2 - 3 * mu**4

    skew = mu3 / (mu2 ** 1.5) if mu2 > 0 else np.nan
    kurt_excess = mu4 / (mu2**2) - 3.0 if mu2 > 0 else np.nan
    return mu, mu2, skew, kurt_excess


def dm_entropy_small_n(alpha, n: int) -> float:
    alpha = _validate_alpha(alpha)
    xs = enumerate_support(n=n, k=alpha.size)
    logp = dirichlet_multinomial_logpmf(xs, alpha=alpha, n=n)
    p = np.exp(logp)
    return float(-np.sum(p * logp))


def dm_mgf_small_n(t, alpha, n: int) -> float:
    alpha = _validate_alpha(alpha)
    t = np.asarray(t, dtype=float)
    if t.shape != alpha.shape:
        raise ValueError(f"t must have shape {alpha.shape}")

    xs = enumerate_support(n=n, k=alpha.size)
    logp = dirichlet_multinomial_logpmf(xs, alpha=alpha, n=n)
    return float(np.exp(logsumexp(logp + xs @ t)))


# Example: moments for a 3-category model
alpha = np.array([1.5, 2.0, 4.5])
n = 20

mean = dm_mean(alpha, n=n)
cov = dm_cov(alpha, n=n)
ent = dm_entropy_small_n(alpha, n=n)

mean, cov, ent
(array([ 3.75,  5.  , 11.25]),
 array([[ 9.4792, -2.9167, -6.5625],
        [-2.9167, 11.6667, -8.75  ],
        [-6.5625, -8.75  , 15.3125]]),
 4.885255849427407)
# Marginal skewness/kurtosis via the Beta–binomial identity
alpha0 = alpha.sum()
for i in range(alpha.size):
    a = alpha[i]
    b = alpha0 - alpha[i]
    m, v, s, kex = beta_binomial_moments_via_factorials(n=n, a=float(a), b=float(b))
    print(
        f"X_{i+1}: mean={m:.3f}, var={v:.3f}, skew={s:.3f}, excess_kurtosis={kex:.3f} "
        f"(BetaBinomial(n={n}, a={a:.2f}, b={b:.2f}))"
    )

# Cross-check against SciPy's betabinom for one component
i = 0
a = float(alpha[i])
b = float(alpha0 - alpha[i])
scipy_mean, scipy_var, scipy_skew, scipy_kex = stats.betabinom.stats(n=n, a=a, b=b, moments="mvsk")
(scipy_mean, scipy_var, scipy_skew, scipy_kex)
X_1: mean=3.750, var=9.479, skew=0.974, excess_kurtosis=0.711 (BetaBinomial(n=20, a=1.50, b=6.50))
X_2: mean=5.000, var=11.667, skew=0.703, excess_kurtosis=0.097 (BetaBinomial(n=20, a=2.00, b=6.00))
X_3: mean=11.250, var=15.312, skew=-0.153, excess_kurtosis=-0.537 (BetaBinomial(n=20, a=4.50, b=3.50))
(3.75, 9.479166666666666, 0.9743975315293802, 0.7108891108891111)


## 5) Parameter Interpretation

Think of $\alpha$ as **prior pseudo-counts** for category probabilities.

- The **mean proportions** are

  $$\mathbb{E}[p_i] = \frac{\alpha_i}{\alpha_0},\quad \alpha_0=\sum_i\alpha_i.$$

- The **total concentration** $\alpha_0$ controls how much $p$ varies across replicates:

  - small $\alpha_0$ → $p$ is highly variable → counts are **more dispersed** (more mass near simplex corners)
  - large $\alpha_0$ → $p$ concentrates near its mean → counts look more like a plain multinomial

- Holding $\alpha_0$ fixed, changing the **ratios** $\alpha_i/\alpha_0$ shifts mass toward categories with larger ratios.

Below we visualize samples for the same mean proportions but different total concentration.
# Same mean proportions, different concentration alpha0
n = 25
base = np.array([1.0, 2.0, 3.0])
base = base / base.sum()  # mean proportions

scales = [0.3, 1.0, 5.0]
size = 2500

fig = go.Figure()

# draw simplex triangle
tri_x = [0.0, 1.0, 0.5, 0.0]
tri_y = [0.0, 0.0, np.sqrt(3) / 2.0, 0.0]
fig.add_trace(
    go.Scatter(x=tri_x, y=tri_y, mode="lines", line=dict(color="black"), showlegend=False)
)

for s in scales:
    alpha_s = s * base * 30.0  # scale into a reasonable pseudo-count regime
    samples = dirichlet_multinomial_rvs_numpy(alpha=alpha_s, n=n, size=size, rng=rng)
    x, y = simplex_xy_3(samples)
    fig.add_trace(
        go.Scattergl(
            x=x,
            y=y,
            mode="markers",
            name=f"alpha0≈{alpha_s.sum():.1f}",
            marker=dict(size=4, opacity=0.25),
        )
    )

fig.update_layout(
    title="Dirichlet–multinomial samples on the 3-simplex (same mean, different concentration)",
    xaxis_title="barycentric x",
    yaxis_title="barycentric y",
    xaxis=dict(scaleanchor="y", scaleratio=1),
    width=850,
    height=500,
)
fig.show()


## 6) Derivations

We derive mean and covariance using the mixture representation:

$$
p \sim \mathrm{Dirichlet}(\alpha),
\qquad
X \mid p \sim \mathrm{Multinomial}(n, p).
$$

### Expectation
By the law of total expectation:

$$
\mathbb{E}[X_i] = \mathbb{E}\big[\,\mathbb{E}[X_i\mid p]\,\big]
= \mathbb{E}[n p_i]
= n\,\mathbb{E}[p_i]
= n\,\frac{\alpha_i}{\alpha_0}.
$$

### Variance
By the law of total variance:

$$
\mathrm{Var}(X_i) = \mathbb{E}[\mathrm{Var}(X_i\mid p)] + \mathrm{Var}(\mathbb{E}[X_i\mid p]).
$$

For a multinomial:

$$
\mathbb{E}[X_i\mid p] = n p_i,
\qquad
\mathrm{Var}(X_i\mid p) = n p_i(1-p_i).
$$

So:

$$
\mathrm{Var}(X_i)
= \mathbb{E}[n p_i(1-p_i)] + \mathrm{Var}(n p_i)
= n\,\mathbb{E}[p_i - p_i^2] + n^2\,\mathrm{Var}(p_i).
$$

Using Dirichlet moments
$\mathbb{E}[p_i]=\alpha_i/\alpha_0$ and $\mathrm{Var}(p_i)=\alpha_i(\alpha_0-\alpha_i)/(\alpha_0^2(\alpha_0+1))$
yields the variance formula in Section 4.

### Covariance
Similarly, for $i\ne j$:

$$
\mathrm{Cov}(X_i, X_j)
= \mathbb{E}[\mathrm{Cov}(X_i,X_j\mid p)] + \mathrm{Cov}(\mathbb{E}[X_i\mid p],\mathbb{E}[X_j\mid p]).
$$

For a multinomial, $\mathrm{Cov}(X_i,X_j\mid p) = -n p_i p_j$ for $i\ne j$.
With Dirichlet moments for $\mathbb{E}[p_i p_j]$, you arrive at the negative covariance formula.

### Likelihood (for fitting $\alpha$)
Given an observed count vector $x$ with total $n$, the likelihood as a function of $\alpha$ is:

$$
L(\alpha; x)
\propto
\frac{\Gamma(\alpha_0)}{\Gamma(\alpha_0+n)}
\prod_{i=1}^K \frac{\Gamma(\alpha_i + x_i)}{\Gamma(\alpha_i)}.
$$

Taking logs gives:

$$
\ell(\alpha; x)
= \log\Gamma(\alpha_0) - \log\Gamma(\alpha_0+n)
  + \sum_i \big(\log\Gamma(\alpha_i + x_i) - \log\Gamma(\alpha_i)\big)
  + \text{const}(x).
$$

There is no closed-form MLE for $\alpha$ in general; you typically optimize $\ell(\alpha)$ numerically.
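The total-variance decomposition above can itself be checked by Monte Carlo over $p$ alone, since both conditional pieces are known in closed form. A self-contained sketch for one component:

```python
import numpy as np

rng_lv = np.random.default_rng(1)
alpha = np.array([1.5, 2.0, 4.5])
a0 = alpha.sum()
n = 25
i = 0  # component to check

# Only Dirichlet samples are needed: the multinomial pieces are analytic
ps = rng_lv.dirichlet(alpha, size=200_000)
p_i = ps[:, i]

ev = np.mean(n * p_i * (1.0 - p_i))  # E[Var(X_i | p)]
ve = np.var(n * p_i)                 # Var(E[X_i | p])

p_bar = alpha[i] / a0
exact = n * p_bar * (1.0 - p_bar) * (n + a0) / (a0 + 1.0)
print(ev + ve, exact)  # the two should agree up to Monte Carlo error
```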


## 7) Sampling & Simulation

A simple **NumPy-only** sampling algorithm follows directly from the hierarchical story:

1. Sample $p \sim \mathrm{Dirichlet}(\alpha)$.
   A standard implementation uses Gamma variables: draw $g_i \sim \mathrm{Gamma}(\alpha_i, 1)$ and set $p_i = g_i / \sum_j g_j$.
2. Sample counts $X \mid p \sim \mathrm{Multinomial}(n, p)$.

This is exactly what `dirichlet_multinomial_rvs_numpy` implements.

Below we verify mean/covariance by Monte Carlo.
alpha = np.array([1.5, 2.0, 4.5])
n = 25

theory_mean = dm_mean(alpha, n=n)
theory_cov = dm_cov(alpha, n=n)

samples = dirichlet_multinomial_rvs_numpy(alpha=alpha, n=n, size=50_000, rng=rng)
sample_mean = samples.mean(axis=0)
sample_cov = np.cov(samples.T, ddof=0)

print('theory mean:', theory_mean)
print('sample mean:', sample_mean)
print('max abs mean error:', np.max(np.abs(sample_mean - theory_mean)))
print('max abs cov error:', np.max(np.abs(sample_cov - theory_cov)))


## 8) Visualization

Because the Dirichlet–multinomial is multivariate, visuals depend on $K$:

- For $K=2$ it reduces to a **Beta–binomial** and you can plot a standard PMF/CDF over $\{0,1,\dots,n\}$.
- For $K=3$ you can plot probabilities/samples on a 2D simplex (triangle).

We do both.
# K=2: PMF and CDF (Beta–binomial view)
n = 30
alpha2 = np.array([2.0, 5.0])

xs = np.arange(n + 1)
pmf = np.array([dirichlet_multinomial_pmf([x, n - x], alpha=alpha2, n=n) for x in xs])
cdf = np.cumsum(pmf)

fig = go.Figure()
fig.add_trace(go.Bar(x=xs, y=pmf, name="PMF"))
fig.update_layout(
    title="Dirichlet–multinomial with K=2 (PMF of X1)",
    xaxis_title="x",
    yaxis_title="P(X1=x)",
    bargap=0.05,
    width=850,
    height=380,
)
fig.show()

fig = go.Figure()
fig.add_trace(go.Scatter(x=xs, y=cdf, mode="lines+markers", name="CDF"))
fig.update_layout(
    title="Dirichlet–multinomial with K=2 (CDF of X1)",
    xaxis_title="x",
    yaxis_title="P(X1≤x)",
    width=850,
    height=380,
)
fig.show()
# Monte Carlo samples vs PMF (K=2)
size = 20_000
s = dirichlet_multinomial_rvs_numpy(alpha=alpha2, n=n, size=size, rng=rng)
x1 = s[:, 0]

fig = px.histogram(x1, nbins=n + 1, histnorm="probability", title="Monte Carlo histogram vs PMF")
fig.add_trace(go.Scatter(x=xs, y=pmf, mode="lines", name="PMF", line=dict(color="black")))
fig.update_layout(xaxis_title="x1", yaxis_title="probability", width=850, height=420)
fig.show()
# K=3: PMF on the simplex (small n, exact enumeration)
n = 20
alpha3 = np.array([1.2, 2.5, 4.0])

support = enumerate_support(n=n, k=3)
logp = dirichlet_multinomial_logpmf(support, alpha=alpha3, n=n)
p = np.exp(logp)

sx, sy = simplex_xy_3(support)

fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=sx,
        y=sy,
        mode="markers",
        marker=dict(
            size=10,
            color=np.log10(p),
            colorscale="Viridis",
            colorbar=dict(title="log10 PMF"),
        ),
        text=[str(tuple(row)) for row in support],
        hovertemplate="x=%{text}<br>log10 p=%{marker.color:.3f}<extra></extra>",
        name="support",
    )
)

tri_x = [0.0, 1.0, 0.5, 0.0]
tri_y = [0.0, 0.0, np.sqrt(3) / 2.0, 0.0]
fig.add_trace(go.Scatter(x=tri_x, y=tri_y, mode="lines", line=dict(color="black"), showlegend=False))

fig.update_layout(
    title="Dirichlet–multinomial PMF on the 3-simplex (exact, enumerated support)",
    xaxis_title="barycentric x",
    yaxis_title="barycentric y",
    xaxis=dict(scaleanchor="y", scaleratio=1),
    width=850,
    height=520,
)
fig.show()
# K=3: Monte Carlo samples on the simplex
n = 20
alpha3 = np.array([1.2, 2.5, 4.0])

samples = dirichlet_multinomial_rvs_numpy(alpha=alpha3, n=n, size=6000, rng=rng)
x, y = simplex_xy_3(samples)

fig = go.Figure()
fig.add_trace(
    go.Scattergl(
        x=x,
        y=y,
        mode="markers",
        marker=dict(size=4, opacity=0.25),
        text=[str(tuple(row)) for row in samples],
        hovertemplate="x=%{text}<extra></extra>",
        name="samples",
    )
)

tri_x = [0.0, 1.0, 0.5, 0.0]
tri_y = [0.0, 0.0, np.sqrt(3) / 2.0, 0.0]
fig.add_trace(go.Scatter(x=tri_x, y=tri_y, mode="lines", line=dict(color="black"), showlegend=False))

fig.update_layout(
    title="Dirichlet–multinomial Monte Carlo samples on the 3-simplex",
    xaxis_title="barycentric x",
    yaxis_title="barycentric y",
    xaxis=dict(scaleanchor="y", scaleratio=1),
    width=850,
    height=520,
)
fig.show()


## 9) SciPy Integration

SciPy exposes the Dirichlet–multinomial as `scipy.stats.dirichlet_multinomial`.

In this environment (SciPy version may vary), the object provides:
- `pmf` / `logpmf`
- moment methods like `mean`, `var`, and `cov`

But it may **not** provide `cdf`, `rvs`, or `fit` for this multivariate distribution.
We’ll show:
- how to use SciPy where available
- how to implement missing pieces (CDF by summation for small $n$, sampling via the hierarchical model, and MLE via `scipy.optimize`)
import scipy

print('SciPy version:', scipy.__version__)

n = 10
alpha = np.array([1.0, 2.0, 3.0])
x = np.array([2, 3, 5])

dm = stats.dirichlet_multinomial(n=n, alpha=alpha)

print('pmf:', dm.pmf(x))
print('logpmf:', dm.logpmf(x))
print('mean:', dm.mean())
print('cov:\n', dm.cov())

# Feature check
for name in ['cdf', 'rvs', 'fit']:
    print(name, 'available?', hasattr(dm, name))
# CDF: SciPy doesn't implement a multivariate cdf here, but we can compute it by brute force for small n.
n = 12
alpha = np.array([1.0, 2.0, 3.0])
x = np.array([3, 4, 5])

dm_cdf_small_n(x, alpha=alpha, n=n)
0.016968325791855168
# For K=2, the CDF reduces to the usual univariate Beta–binomial CDF
n = 30
alpha2 = np.array([2.0, 5.0])
xs = np.arange(n + 1)

# X1 ~ BetaBinomial(n, a=alpha1, b=alpha2)
cdf_scipy = stats.betabinom.cdf(xs, n=n, a=alpha2[0], b=alpha2[1])
cdf_numpy = np.cumsum([dirichlet_multinomial_pmf([x, n - x], alpha=alpha2, n=n) for x in xs])

float(np.max(np.abs(cdf_scipy - cdf_numpy)))
2.55351295663786e-15
# Sampling: SciPy's dirichlet_multinomial may not expose rvs, but sampling is easy via the hierarchical model.
# Here is a SciPy-flavored sampler (Dirichlet from SciPy + Multinomial from NumPy):

def dirichlet_multinomial_rvs_scipy(alpha, n: int, size: int, rng: np.random.Generator) -> np.ndarray:
    alpha = _validate_alpha(alpha)
    ps = stats.dirichlet.rvs(alpha, size=size, random_state=rng)
    out = np.empty((size, alpha.size), dtype=int)
    for i, p in enumerate(ps):
        out[i] = rng.multinomial(n, p)
    return out


alpha = np.array([1.0, 2.0, 3.0])
n = 10
samples = dirichlet_multinomial_rvs_scipy(alpha=alpha, n=n, size=5, rng=rng)
samples
array([[1, 1, 8],
       [4, 2, 4],
       [0, 3, 7],
       [0, 5, 5],
       [2, 2, 6]])
# Fit (MLE): optimize the Dirichlet–multinomial log-likelihood for alpha

def dm_loglik(alpha, X: np.ndarray) -> float:
    alpha = _validate_alpha(alpha)
    X = _validate_counts(X, k=alpha.size)
    n_vec = X.sum(axis=1)
    alpha0 = alpha.sum()

    # omit multinomial coefficient terms (constants wrt alpha)
    ll = (
        X.shape[0] * gammaln(alpha0)
        - np.sum(gammaln(alpha0 + n_vec))
        + np.sum(gammaln(alpha + X) - gammaln(alpha), axis=1).sum()
    )
    return float(ll)


def dm_loglik_grad(alpha, X: np.ndarray) -> np.ndarray:
    alpha = _validate_alpha(alpha)
    X = _validate_counts(X, k=alpha.size)
    n_vec = X.sum(axis=1)
    alpha0 = alpha.sum()

    m = X.shape[0]
    common = m * digamma(alpha0) - np.sum(digamma(alpha0 + n_vec))
    grad = common + np.sum(digamma(alpha + X), axis=0) - m * digamma(alpha)
    return grad


def dm_mom_alpha_init(X: np.ndarray) -> np.ndarray:
    '''Method-of-moments-ish initializer for alpha (works best when n is constant).'''
    X = np.asarray(X, dtype=float)
    X = np.atleast_2d(X)

    n_vec = X.sum(axis=1)
    if not np.allclose(n_vec, n_vec[0]):
        # fall back: mean proportions with moderate concentration
        p_hat = X.sum(axis=0) / X.sum()
        return 20.0 * p_hat

    n = float(n_vec[0])
    p_hat = X.mean(axis=0) / n
    s2 = X.var(axis=0, ddof=0)

    # v_i ≈ Var / (n p(1-p)) = (n+alpha0)/(alpha0+1)
    denom = n * p_hat * (1.0 - p_hat)
    usable = denom > 1e-12
    v = np.median((s2[usable] / denom[usable]).clip(min=1.0)) if np.any(usable) else 1.0

    if v <= 1.0 + 1e-8:
        alpha0 = 1e3
    else:
        alpha0 = (n - v) / (v - 1.0)
        alpha0 = float(np.clip(alpha0, 1e-3, 1e4))

    return alpha0 * p_hat


def fit_dirichlet_multinomial_mle(X: np.ndarray, alpha_init: np.ndarray | None = None) -> np.ndarray:
    X = _validate_counts(X, k=np.asarray(X).shape[-1])
    k = X.shape[1]

    if alpha_init is None:
        alpha_init = dm_mom_alpha_init(X)
    alpha_init = np.asarray(alpha_init, dtype=float)
    if alpha_init.shape != (k,):
        raise ValueError(f"alpha_init must have shape ({k},)")

    # optimize over log(alpha) to enforce positivity
    x0 = np.log(alpha_init)

    bounds = [(-10.0, 10.0)] * k  # keeps alpha in a safe numeric range

    def obj(log_alpha):
        a = np.exp(log_alpha)
        return -dm_loglik(a, X)

    def grad(log_alpha):
        a = np.exp(log_alpha)
        return -(dm_loglik_grad(a, X) * a)  # chain rule

    res = minimize(obj, x0=x0, jac=grad, method='L-BFGS-B', bounds=bounds)
    if not res.success:
        raise RuntimeError(f"MLE optimization failed: {res.message}")
    return np.exp(res.x)


# Demo: simulate + fit
rng_fit = np.random.default_rng(0)

alpha_true = np.array([1.2, 2.5, 4.0])
n = 20
m = 300

X = dirichlet_multinomial_rvs_numpy(alpha=alpha_true, n=n, size=m, rng=rng_fit)
alpha_hat = fit_dirichlet_multinomial_mle(X)

alpha_true, alpha_hat, alpha_hat / alpha_hat.sum()
(array([1.2, 2.5, 4. ]),
 array([1.0448, 2.4962, 3.8207]),
 array([0.1419, 0.3391, 0.519 ]))


## 10) Statistical Use Cases

### 1) Hypothesis testing: multinomial vs overdispersed counts
A common question is whether a plain multinomial is *too restrictive*.
You can compare:

- **$H_0$**: $X \sim \mathrm{Multinomial}(n, p)$ (fixed $p$)
- **$H_1$**: $X \sim \mathrm{DirichletMultinomial}(n, \alpha)$ (random $p$)

A likelihood ratio statistic can be used, but the usual $\chi^2$ reference is unreliable because the multinomial is a boundary case (roughly $\alpha_0\to\infty$).
A practical approach is a **parametric bootstrap** under $H_0$.

### 2) Bayesian modeling: posterior predictive
If you place a Dirichlet prior on multinomial probabilities, the posterior is Dirichlet and the **posterior predictive** for new counts is Dirichlet–multinomial.

### 3) Generative modeling
Dirichlet–multinomial is a natural “bag-of-words” generator: it samples a document-level word distribution and then generates word counts.
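A minimal bag-of-words generator along these lines (the vocabulary, parameters, and `generate_document` helper are illustrative):

```python
import numpy as np

rng_gen = np.random.default_rng(3)

vocab = ['cat', 'dog', 'fish', 'bird']   # toy vocabulary
alpha = np.array([2.0, 2.0, 1.0, 0.5])   # corpus-level concentration (illustrative)

def generate_document(n_words: int) -> np.ndarray:
    # Draw a document-level word distribution, then word counts given it
    p = rng_gen.dirichlet(alpha)
    return rng_gen.multinomial(n_words, p)

docs = np.array([generate_document(50) for _ in range(3)])
for counts in docs:
    print(dict(zip(vocab, counts.tolist())))
```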
# Bayesian modeling: Dirichlet posterior + Dirichlet–multinomial posterior predictive

alpha_prior = np.array([1.0, 1.0, 1.0])
x_obs = np.array([4, 1, 5])

alpha_post = alpha_prior + x_obs

print('prior mean p:', alpha_prior / alpha_prior.sum())
print('posterior mean p:', alpha_post / alpha_post.sum())

# Posterior predictive for future n_new counts
n_new = 12
x_future = np.array([3, 5, 4])
p_pred = dirichlet_multinomial_pmf(x_future, alpha=alpha_post, n=n_new)
p_pred
prior mean p: [0.3333 0.3333 0.3333]
posterior mean p: [0.3846 0.1538 0.4615]
0.009784938442900492
# Hypothesis testing demo: parametric bootstrap LRT (small example)

def multinomial_loglik(X: np.ndarray, p: np.ndarray) -> float:
    X = _validate_counts(X, k=p.size)
    p = np.asarray(p, dtype=float)
    if p.ndim != 1 or p.size != X.shape[1]:
        raise ValueError('p must be shape (k,)')
    if np.any(p <= 0):
        raise ValueError('p must be strictly positive (use smoothing if needed)')
    p = p / p.sum()

    n_vec = X.sum(axis=1)
    ll = (
        gammaln(n_vec + 1)
        - np.sum(gammaln(X + 1), axis=1)
        + (X * np.log(p)).sum(axis=1)
    ).sum()
    return float(ll)


def lrt_statistic(X: np.ndarray) -> tuple[float, np.ndarray, np.ndarray]:
    X = _validate_counts(X, k=np.asarray(X).shape[-1])
    n_vec = X.sum(axis=1)
    if not np.all(n_vec == n_vec[0]):
        raise ValueError('This demo assumes constant n across rows')

    # H0: multinomial MLE for p (tiny smoothing guards against empty categories,
    # which would make log(p) undefined)
    p_hat = X.sum(axis=0) / X.sum()
    p_hat = (p_hat + 1e-12) / (p_hat.sum() + 1e-12 * p_hat.size)

    ll0 = multinomial_loglik(X, p_hat)

    # H1: Dirichlet–multinomial MLE for alpha
    alpha_hat = fit_dirichlet_multinomial_mle(X)
    ll1 = dm_loglik(alpha_hat, X)
    # If dm_loglik and multinomial_loglik handle the multinomial coefficient
    # differently, the statistic can come out negative; since the bootstrap
    # uses the same statistic under H0, the comparison remains coherent.

    return 2.0 * (ll1 - ll0), p_hat, alpha_hat


rng_test = np.random.default_rng(123)

# Simulate an overdispersed dataset under H1
alpha_true = np.array([1.2, 2.5, 4.0])
n = 20
m = 80
X = dirichlet_multinomial_rvs_numpy(alpha=alpha_true, n=n, size=m, rng=rng_test)

lrt_obs, p_hat_obs, alpha_hat_obs = lrt_statistic(X)
print('Observed LRT:', lrt_obs)
print('alpha_hat:', alpha_hat_obs)

# Bootstrap under H0 (multinomial)
B = 30
lrt_boot = []
for _ in range(B):
    Xb = rng_test.multinomial(n, p_hat_obs, size=m)
    stat, _, _ = lrt_statistic(Xb)
    lrt_boot.append(stat)

lrt_boot = np.array(lrt_boot)
p_value = float(np.mean(lrt_boot >= lrt_obs))

print('bootstrap LRT mean:', lrt_boot.mean())
print('bootstrap p-value (rough, small B):', p_value)
Observed LRT: -2054.403206622123
alpha_hat: [1.2211 2.5888 3.5766]
bootstrap LRT mean: -2586.3898742580986
bootstrap p-value (rough, small B): 0.0
# Generative modeling example: "documents" as category-count vectors

alpha_topic = np.array([0.4, 0.4, 0.4])  # sparse-ish p for each document
n_words = 60
n_docs = 200

docs = dirichlet_multinomial_rvs_numpy(alpha=alpha_topic, n=n_words, size=n_docs, rng=rng)

# Visualize document-level proportions
props = docs / docs.sum(axis=1, keepdims=True)
fig = px.scatter_3d(
    x=props[:, 0], y=props[:, 1], z=props[:, 2],
    title="Document-level proportions (Dirichlet–multinomial generator)",
    labels={'x': 'p1', 'y': 'p2', 'z': 'p3'}
)
fig.update_traces(marker=dict(size=3, opacity=0.6))
fig.show()


## 11) Pitfalls

        - **Invalid parameters**:
          - $\alpha_i$ must be strictly positive.
          - $x_i$ must be nonnegative integers and must satisfy $\sum_i x_i = n$.

        - **Numerical issues**:
          - PMFs can underflow quickly when $n$ is large. Prefer `logpmf` and compute in log-space.
          - Use `gammaln` / `digamma` rather than `gamma` / factorials.

        - **Combinatorial explosion**:
          - The support size is $\binom{n+K-1}{K-1}$.
          - Exact enumeration (for entropy, CDF, full PMF plots) is only feasible for small $n$ and moderate $K$.

        - **Fitting**:
          - The multinomial is a limiting case ($\alpha_0\to\infty$). In near-multinomial data, MLE may push $\alpha$ very large.
          - Use good initialization and consider bounds / regularization if optimization is unstable.
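To make the combinatorial-explosion bullet concrete: the support size follows from the stars-and-bars identity and can be computed directly. A small sketch using only the standard library:

```python
from math import comb

def support_size(n: int, k: int) -> int:
    """Number of count vectors x with x_i >= 0 and sum(x) = n (stars and bars)."""
    return comb(n + k - 1, k - 1)

# Support grows rapidly with n and K, so exact enumeration quickly becomes infeasible
for n, k in [(10, 3), (50, 3), (50, 10), (200, 10)]:
    print(f"n={n:>3}, K={k:>2}: {support_size(n, k):,}")
```

For $n=10$, $K=3$ there are only 66 count vectors, but $n=200$, $K=10$ already has over $10^{14}$, which is why entropy and CDF computations fall back to Monte Carlo.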


## 12) Summary

        - The Dirichlet–multinomial is the **posterior predictive** distribution for multinomial counts with a Dirichlet prior.
        - It models **overdispersed** multinomial counts by letting the category probabilities vary across replicates.
        - Mean proportions are $\alpha/\alpha_0$; total concentration $\alpha_0$ controls dispersion.
        - PMF evaluation is stable in log-space via Gamma functions.
        - Exact CDF/entropy require summation over a combinatorial support; for larger problems use Monte Carlo or approximations.